/**
 * Copyright 2024-2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef FRAMEWORK_HCL_HIAI_ND_TENSOR_DESC_H
#define FRAMEWORK_HCL_HIAI_ND_TENSOR_DESC_H

#include <stddef.h>
#include <stdint.h>

#include "hiai_types.h"
#include "hiai_base_types.h"
#include "c/hiai_error_types.h"
#include "c/hiai_c_api_export.h"

#ifdef __cplusplus
extern "C" {
#endif
/**
 * @brief 根据tensor的维度信息、数据类型和数据排布格式创建HIAI_NDTensorDesc指针实例。
 *
 * 本方法创建出的{@link HIAI_NDTensorDesc}指针实例，可用于创建{@link HIAI_MR_NDTensorBuffer}指针实例，申请推理时所需的内存。
 * {@link HIAI_NDTensorDesc}不使用时，使用{@link HIAI_NDTensorDesc_Destroy}进行释放，否则将造成内存泄漏。
 *
 * @return 成功时返回指向{@link HIAI_NDTensorDesc}指针实例，失败时返回空指针
 * @see HIAI_NDTensorDesc_Destroy
 */
AICP_C_API_EXPORT HIAI_NDTensorDesc* HIAI_NDTensorDesc_CreateDefault(void);

/**
 * @brief 销毁HIAI_NDTensorDesc指针实例
 *
 * 销毁由{@link HIAI_NDTensorDesc_CreateDefault}创建的{@link HIAI_NDTensorDesc}指针实例，传入指针的引用即可。若不调用，将造成内存泄漏。
 * 当tensorDesc或者*tensorDesc为空指针时，直接返回，不执行任何操作。
 *
 * @param [in] tensorDesc {@link HIAI_NDTensorDesc}指针实例的引用
 */
AICP_C_API_EXPORT void HIAI_NDTensorDesc_Destroy(HIAI_NDTensorDesc** tensorDesc);

/**
 * @brief 设置输入输出tensor的描述信息中指定维度的值
 *
 * @param [in] tensorDesc {@link HIAI_NDTensorDesc}指针实例，非空，否则返回0
 * @param [in] dims 每个维度的数据集
 * @param [in] dimNum 维度数据集的个数
 * @return 成功到时对应index位置的维度值，失败时返回0
 */
AICP_C_API_EXPORT HIAI_Status HIAI_NDTensorDesc_SetDims(
    HIAI_NDTensorDesc* tensorDesc, const int32_t* dims, size_t dimNum);

/**
 * @brief 查询输入输出tensor的描述信息中的维度数量
 *
 * @param [in] tensorDesc {@link HIAI_NDTensorDesc}指针实例，非空，否则返回0
 * @return 成功时返回描述信息中维度的数量，失败时返回0
 */
AICP_C_API_EXPORT size_t HIAI_NDTensorDesc_GetDimNum(const HIAI_NDTensorDesc* tensorDesc);

/**
 * @brief 查询输入输出tensor的描述信息中指定维度的值
 *
 * @param [in] tensorDesc {@link HIAI_NDTensorDesc}指针实例，非空，否则返回0
 * @param [in] index 维度的位置坐标，从0开始计算
 * @return 成功到时对应index位置的维度值，失败时返回0
 */
AICP_C_API_EXPORT int32_t HIAI_NDTensorDesc_GetDim(const HIAI_NDTensorDesc* tensorDesc, size_t index);

/**
 * @brief 设置输入输出tensor的描述信息中数据的类型
 *
 * @param [in] tensorDesc {@link HIAI_NDTensorDesc}指针实例，非空，否则返回失败
 * @param [in] dataType 对应数据类型{@link HiAI_DataType}
 * @return 成功到时HIAI_SUCCESS 失败时返回HIAI_FAILURE
 */
AICP_C_API_EXPORT HIAI_Status HIAI_NDTensorDesc_SetDataType(HIAI_NDTensorDesc* tensorDesc, HIAI_DataType dataType);

/**
 * @brief 查询输入输出tensor的描述信息中数据的类型
 *
 * @param [in] tensorDesc {@link HIAI_NDTensorDesc}指针实例，非空，否则返回失败
 * @return 成功到时返回对应数据类型{@link HiAI_DataType}，失败时返回{@link HiAI_DataType_UINT8}
 */
AICP_C_API_EXPORT HIAI_DataType HIAI_NDTensorDesc_GetDataType(const HIAI_NDTensorDesc* tensorDesc);

/**
 * @brief 设置输入输出tensor的描述信息中数据排布格式
 *
 * @param [in] tensorDesc {@link HIAI_NDTensorDesc}指针实例，非空，否则返回失败
 * @param [in] format {@link HIAI_Format} 数据排布格式
 * @return 成功时返回对应数据排布格式的{@link HiAI_Format}，失败时返回{@link HiAI_Format_NCHW}
 */
AICP_C_API_EXPORT HIAI_Status HIAI_NDTensorDesc_SetFormat(HIAI_NDTensorDesc* tensorDesc, HIAI_Format format);

/**
 * @brief 查询输入输出tensor的描述信息中数据排布格式
 *
 * @param [in] tensorDesc {@link HIAI_NDTensorDesc}指针实例，非空，否则返回失败
 * @return 成功时返回对应数据排布格式的{@link HiAI_Format}，失败时返回{@link HiAI_Format_NCHW}
 */
AICP_C_API_EXPORT HIAI_Format HIAI_NDTensorDesc_GetFormat(const HIAI_NDTensorDesc* tensorDesc);

/**
 * @brief 查询输入输出tensor的描述信息中数据元素个数
 *
 * @param [in] tensorDesc {@link HIAI_NDTensorDesc}指针实例，非空，否则返回失败
 * @return 成功时返回元素个数，失败时返回0
 */
AICP_C_API_EXPORT size_t HIAI_NDTensorDesc_GetElementNum(const HIAI_NDTensorDesc* tensorDesc);

/**
 * @brief 查询输入输出tensor的描述信息中数据的大小
 *
 * 本方法返回的是所有维度的与数据类型大小的乘积，例如维度为4位，数据类型为float，则返回的值为n*c*h*w*sizeof(Float32)
 *
 * @param [in] tensorDesc {@link HIAI_NDTensorDesc}指针实例，非空，否则返回0
 * @return 成功时返回计算出的结果，失败返回0
 */
AICP_C_API_EXPORT size_t HIAI_NDTensorDesc_GetByteSize(const HIAI_NDTensorDesc* tensorDesc);

#ifdef __cplusplus
}
#endif

#endif // FRAMEWORK_HCL_HIAI_ND_TENSOR_DESC_H
